Pixel Level Calibration¶
There are a lot of existing frameworks for phase retrieval - but it is not so straightforward to simultaneously retrieve
- aberrations (the optical distortions, maybe tens of parameters)
- astrometry (the positions of stars, tens of parameters)
- interpixel sensitivity (aka the 'flat field', on a large pixel grid!)
The dimensionality of the pixel grid can be so high it is hard to retrieve without autodiff. With dLux, it is easy*!
* This is the topic of a Desdoigts et al paper in prep. It wasn't quite that easy to build and we're pretty happy about it.
First, import everything as usual:
# Core jax
import jax
import jax.numpy as np
import jax.random as jr
# Optimisation
import zodiax as zdx
import optax
# Optics
import dLux as dl
from dLux.utils import arcseconds_to_radians as a2r
from dLux.utils import radians_to_arcseconds as r2a
# Plotting/visualisation
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
# Notebook display settings
%matplotlib inline
plt.rcParams['image.cmap'] = 'inferno'   # perceptually-uniform colormap
plt.rcParams["font.family"] = "serif"
plt.rcParams["image.origin"] = 'lower'   # put pixel (0, 0) at the bottom-left
plt.rcParams['figure.dpi'] = 120
dLux: Jax is running in 32-bit, to enable 64-bit visit: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
Then we generate an array of 5 dithered images, in a coarsely-sampled broad bandwidth, of 15 stars. The reason we want multiple stars, and especially dithered images, is because we want to retrieve the flat field - multiplicative errors on each pixel that, if we only have the one image, are formally degenerate with the PSF. It turns out that if we can hit each pixel with light from the same PSF multiple times, we can retrieve the whole thing unambiguously.
We will use the presaved default Toliman pupil - a diffractive pupil that nicely spreads out the PSF in a highly structured pattern, designed (with dLux) to be ideal for astrometry and field calibration.
We will then add some Zernike mode aberrations on top of this, and have some 5% level flat field calibration errors, and additive background noise.
# Basic Optical Parameters
diameter = 0.5   # aperture diameter (m)
wf_npix = 256    # pixels across the wavefront array

# Detector Parameters
det_npix = 256
det_pixsize = a2r(10/det_npix)  # 10 arcseconds across the full detector

# Use an independent PRNG key for every random quantity: reusing a single
# key for several draws yields correlated values (JAX PRNG best practice).
key_coeff, key_ff, key_pos, key_flux = jr.split(jr.PRNGKey(0), 4)

# Load the diffractive pupil mask and scale it to the wavefront resolution
# (presumably an OPD map normalised to a ~600 nm design wavelength - confirm)
raw_mask = np.load("files/test_mask.npy") * (6e-7/(2*np.pi))
mask = dl.utils.scale_array(raw_mask, wf_npix, 0)

# Zernike Basis: indices 3-9, with ~20 nm scale random coefficients
zern_basis = np.arange(3, 10)
coeffs = 2e-8 * jr.normal(key_coeff, [len(zern_basis)])

# Define Optical Configuration
optical_layers = [
    dl.CreateWavefront    (wf_npix, diameter),
    dl.ApertureFactory    (wf_npix, zernikes=zern_basis, coefficients=coeffs),
    dl.AddOPD             (mask),
    dl.NormaliseWavefront (),
    dl.AngularMFT         (det_npix, det_pixsize)]

# Create Optics object
optics = dl.Optics(optical_layers)

# Pixel response: ~5% rms multiplicative flat-field errors
pix_response = 1 + 0.05*jr.normal(key_ff, [det_npix, det_npix])

# Create Detector object
detector = dl.Detector([dl.ApplyPixelResponse(pix_response)])

# Multiple sources to observe: uniform positions over +/-4 arcsec,
# fluxes ~1e8 photons with 10% scatter
Nstars = 15
true_positions = a2r(jr.uniform(key_pos, (Nstars, 2), minval=-4, maxval=4))
true_fluxes = 1e8 + 1e7*jr.normal(key_flux, (Nstars,))
wavels = 1e-9 * np.linspace(545, 645, 3)  # coarse 545-645 nm bandpass

# Create Source object
source = dl.MultiPointSource(true_positions, true_fluxes, wavelengths=wavels)
Now we need to introduce the dithers. To do this we define an 'observation function' that we use to update the relevant parameters and model the sources. Instruments have a pre-built dither_and_model(dithers) function that does this for us!
With the observation function, we then put this inside of a dictionary under the key 'fn' along with the dithers under 'args'. We can then use the .observe() method to call the function stored under 'fn' with the input arguments 'args'. This is how we allow an arbitrary observation strategy to be modelled simply!
# Observation strategy: a 5-point dither pattern (centre plus the four
# diagonal neighbours), each offset by one detector pixel.
dithers = det_pixsize * np.array([[0, 0], [+1, +1], [+1, -1], [-1, +1], [-1, -1]])
observation = dl.Dither(dithers)

# Combine optics, sources, detector and observation into an instrument
tel = dl.Instrument(optics=optics, sources=[source], detector=detector,
                    observation=observation)

# Observe! Returns one PSF image per dither position.
psfs = tel.observe()
# Apply noise to the PSFs: photon (Poisson) noise plus an additive
# background. Use two distinct PRNG keys - a key must never feed two draws.
key_bg, key_phot = jr.split(jr.PRNGKey(0))
BG_noise = np.abs(5*jr.normal(key_bg, psfs.shape))
data = jr.poisson(key_phot, psfs) + BG_noise

# Show the five noisy dithered frames
plt.figure(figsize=(25, 4))
for i, frame in enumerate(data):
    plt.subplot(1, 5, i+1)
    plt.imshow(frame)
    plt.colorbar()
plt.show()
Now there are 4 sets of parameters we are going to learn:
- Positions
- Fluxes
- Zernike aberrations
- Pixel responses
We start by defining the paths to those parameters. We will define them individually so we can refer to them easily later
# Zodiax paths addressing each parameter group inside the model tree.
# Kept as individual names so we can refer to them one at a time later.
positions = 'MultiPointSource.position'
fluxes = 'MultiPointSource.flux'
zernikes = 'CircularAperture.coefficients'
flatfield = 'ApplyPixelResponse.pixel_response'

# The full set, in the order the optimisers will be matched to them.
parameters = [
    positions,
    fluxes,
    zernikes,
    flatfield,
]
Each of these parameters needs a different initialisation
- Positions need to be shifted by some random value
- Fluxes need to be multiplied by some random value
- Zernike coefficients need to be set to zero
- Pixel response values need to be set to one
Perturb the values to initialise the model
# Perturb the true model to build an initial guess, using independent
# PRNG keys for each perturbation (never reuse one key for two draws).
key_pos_init, key_flux_init = jr.split(jr.PRNGKey(0))

# Add small (~2.5 pixel) random offsets to the positions
model = tel.add(positions, 2.5*det_pixsize*jr.normal(key_pos_init, (Nstars, 2)))
# Multiply the fluxes by ~10% random factors
model = model.multiply(fluxes, 1 + 0.1*jr.normal(key_flux_init, (Nstars,)))
# Set the zernike coefficients to zero (assume no aberration knowledge)
model = model.set(zernikes, np.zeros(len(zern_basis)))
# Set the flat field to uniform
model = model.set(flatfield, np.ones((det_npix, det_npix)))
# Generate psfs from the perturbed starting model
psfs = model.observe()
Model and observe — the initial residuals are pretty bad:
# Plot the starting-model residuals against the data, one panel per dither.
# Iterate the frames directly rather than indexing via range(len(...)).
plt.figure(figsize=(25, 4))
for i, (psf, frame) in enumerate(zip(psfs, data)):
    plt.subplot(1, 5, i+1)
    plt.imshow(psf - frame)
    plt.colorbar()
plt.show()
Now we want to generate an optax optimiser object that we can use to train each parameter individually. Because each parameter acts on the loss function at a different scale and with a different effect, we need to be able to set individual learning rates and optimisation schedules for every parameter. Luckily we have built some functions to help specifically with that! Let's see how to use them!
# The flat field is highly covariant with the mean flux level, so hold its
# learning rate near zero until epoch 100, then scale it up by 1e8.
FF_sched = optax.piecewise_constant_schedule(init_value=1e-2*1e-8,
                                             boundaries_and_scales={100: int(1e8)})

# One adam optimiser per parameter group, each with a learning rate matched
# to that group's scale, ordered to line up with `parameters`.
optimisers = [optax.adam(2e-8),      # positions (radians)
              optax.adam(1e6),       # fluxes (photons)
              optax.adam(2e-9),      # zernike coefficients (metres)
              optax.adam(FF_sched)]  # flat field (delayed schedule)

# Build the combined optax optimiser, its state, and the leaves to optimise
optim, opt_state, args = model.get_optimiser(parameters, optimisers, get_args=True)
Poisson log-likelihood:
@zdx.filter_jit
@zdx.filter_value_and_grad(parameters)
def loss_fn(model, data):
    """Poisson negative log-likelihood of `data` under the model.

    The zdx.filter_value_and_grad decorator returns both the loss and its
    gradient with respect to the leaves listed in `parameters`; filter_jit
    compiles the whole thing.
    """
    out = model.observe()
    return -np.sum(jax.scipy.stats.poisson.logpmf(data, out))
Call once to jit compile:
%%time
# The first call triggers jit compilation, so it is much slower than
# subsequent calls - time it once here.
loss, grads = loss_fn(model, data) # Compile
print("Initial Loss: {}".format(int(loss)))
Initial Loss: 1270765568 CPU times: user 3.88 s, sys: 113 ms, total: 3.99 s Wall time: 911 ms
Run gradient descent:
# Plain gradient descent: evaluate loss and gradients, apply the optax
# updates, and snapshot the model every epoch for later analysis.
losses, models_out = [], []
with tqdm(range(200), desc='Gradient Descent') as pbar:
    for epoch in pbar:
        loss, grads = loss_fn(model, data)
        updates, opt_state = optim.update(grads, opt_state)
        model = model.apply_updates(updates)

        # Record the trajectory
        losses.append(loss)
        models_out.append(model)

        # update the progress bar
        pbar.set_description("Log Loss: {:.3f}".format(np.log10(loss)))
Gradient Descent: 0%| | 0/200 [00:00<?, ?it/s]
Format the output into arrays:
# Collect the per-epoch value of each parameter group into arrays,
# plus the final model's PSFs.
nepochs = len(models_out)
psfs_out = models_out[-1].observe()

positions_found = np.array([m.get(positions) for m in models_out])
fluxes_found = np.array([m.get(fluxes) for m in models_out])
zernikes_found = np.array([m.get(zernikes) for m in models_out])
flatfields_found = np.array([m.get(flatfield) for m in models_out])
Pull out the quantities to be plotted - eg final model and residuals:
# Per-epoch residuals for each parameter group.
coeff_residuals = coeffs - zernikes_found
flux_residuals = true_fluxes - fluxes_found

# Radial positional error per star, converted from radians to arcseconds.
# (Removed an unused `scaler` variable that was never referenced.)
positions_residuals = true_positions - positions_found
r_residuals_rads = np.hypot(positions_residuals[:, :, 0], positions_residuals[:, :, 1])
r_residuals = r2a(r_residuals_rads)
# Convergence figure: loss curve plus the residual trajectory of each
# parameter group across the j recorded epochs.
j = len(models_out)
plt.figure(figsize=(16, 13))

# Panel 1: log-loss (no zero reference line)
plt.subplot(3, 2, 1)
plt.title("Log10 Loss")
plt.xlabel("Epochs")
plt.ylabel("Log10 ADU")
plt.plot(np.log10(np.array(losses)[:j]))

# Panels 2-4 share a common layout: residual vs epoch with a zero line.
residual_panels = [
    ("Stellar Positions", "Positional Error (arcseconds)", r_residuals),
    ("Stellar Fluxes", "Flux Error (Photons)", flux_residuals),
    ("Zernike Coeff Residuals", "Residual Amplitude", coeff_residuals),
]
for k, (title, ylabel, series) in enumerate(residual_panels):
    plt.subplot(3, 2, k + 2)
    plt.title(title)
    plt.xlabel("Epochs")
    plt.ylabel(ylabel)
    plt.plot(series[:j])
    plt.axhline(0, c='k', alpha=0.5)

plt.tight_layout()
plt.show()
How did the phase retrieval go? Really well, as it happens!
# OPDs: compare the true optical path difference map to the recovered one.
true_opd = tel.CircularAperture.get_opd()
opds_found = np.array([model.CircularAperture.get_opd() for model in models_out])
found_opd = opds_found[-1]
opd_residuls = true_opd - opds_found
# RMS OPD residual per epoch, converted to nanometres
opd_rmse_nm = 1e9*np.mean(opd_residuls**2, axis=(-1,-2))**0.5
# Colour limits spanning both OPD maps.
# NOTE(review): these limits are later applied to the *residual* panel,
# whose natural range is much smaller - confirm that is intended.
vmin = np.min(np.array([true_opd, found_opd]))
vmax = np.max(np.array([true_opd, found_opd]))
# Coefficients: true vs final recovered values
true_coeff = tel.get(zernikes)
found_coeff = models_out[-1].get(zernikes)
# NOTE(review): zern_basis covers indices 3-9 but this labels the x-axis
# 4-10 - looks like an off-by-one; verify against the Zernike indexing
# convention used by dLux before trusting the plot labels.
index = np.arange(len(true_coeff))+4
# Aberration-recovery figure: RMS OPD convergence, coefficient comparison,
# and the true / found / residual OPD maps.
plt.figure(figsize=(20, 10))
plt.suptitle("Optical Aberrations")

# RMS OPD residual vs epoch
plt.subplot(2, 2, 1)
plt.title("RMS OPD residual")
plt.xlabel("Epochs")
plt.ylabel("RMS OPD (nm)")
plt.plot(opd_rmse_nm)
plt.axhline(0, c='k', alpha=0.5)

# True vs recovered Zernike coefficients, with their difference as bars
plt.subplot(2, 2, 2)
plt.title("Zernike Coefficient Amplitude")
plt.xlabel("Index")
plt.ylabel("Amplitude")
plt.scatter(index, true_coeff, label="True Value")
plt.scatter(index, found_coeff, label="Recovered Value", marker='x')
plt.bar(index, true_coeff - found_coeff, label='Residual')
plt.axhline(0, c='k', alpha=0.5)
plt.legend()

# OPD map panels; only the residual panel pins the colour scale.
opd_panels = [
    ("True OPD", true_opd, {}),
    ("Found OPD", found_opd, {}),
    ("OPD Residual", true_opd - found_opd, dict(vmin=vmin, vmax=vmax)),
]
for k, (title, img, im_kwargs) in enumerate(opd_panels):
    plt.subplot(2, 3, k + 4)
    plt.title(title)
    plt.imshow(img, **im_kwargs)
    plt.colorbar()
plt.show()
Most impressively, we are getting the tens of thousands of parameters of the flat field pretty well too!
# Calculate the mask of pixels that received enough flux for the flat
# field to be constrained by the data. (Removed unused locals `fmask` and
# `pr_residuals`, which were never referenced again.)
thresh = 2500
out_mask = np.where(data.mean(0) < thresh)
in_mask = np.where(data.mean(0) >= thresh)

# Tiled mask: one copy of the mean image per recorded epoch, so the
# per-epoch flat-field stack can be masked in one shot.
data_tile = np.tile(data.mean(0), [len(models_out), 1, 1])
in_mask_tiled = np.where(data_tile >= thresh)

# For the correlation plot, set unconstrained pixels to exactly 1 so they
# sit on the identity line rather than scattering.
true_pr_masked = pix_response.at[out_mask].set(1)
found_pr_masked = flatfields_found[-1].at[out_mask].set(1)

# FF scatter plot prep: colour each pixel by its total illumination,
# sorted so the brightest (best-constrained) points are drawn last.
data_sum = data.sum(0)
colours = data_sum.flatten()
ind = np.argsort(colours)
colours = colours[ind]
pr_true_flat = true_pr_masked.flatten()
pr_found_flat = found_pr_masked.flatten()
pr_true_sort = pr_true_flat[ind]
pr_found_sort = pr_found_flat[ind]

# Per-epoch mean absolute sensitivity error over the constrained pixels
pfound = flatfields_found[in_mask_tiled].reshape([len(models_out), len(in_mask[0])])
ptrue = pix_response[in_mask]
pr_res = ptrue - pfound
masked_error = np.abs(pr_res).mean(-1)
# Flat-field recovery figure. The scatter-plot inputs (data_sum, colours,
# ind, pr_true_sort, pr_found_sort) were already computed above; the exact
# duplicate recomputation that used to sit here has been removed.
plt.figure(figsize=(20, 10))

# Mean sensitivity error vs epoch (wide panel)
plt.subplot(2, 3, (1,2))
plt.title("Pixel Response")
plt.xlabel("Epochs")
plt.ylabel("Mean Sensitivity Error")
plt.plot(masked_error)
plt.axhline(0, c='k', alpha=0.5)

# True vs recovered sensitivity, coloured by total illumination,
# with the identity line for reference
plt.subplot(2, 3, 3)
plt.plot(np.linspace(0.8, 1.2), np.linspace(0.8, 1.2), c='k', alpha=0.75)
plt.scatter(pr_true_sort, pr_found_sort, c=colours, alpha=0.5)
plt.colorbar()
plt.title("Sensitivity Residual")
plt.ylabel("Recovered Sensitivity")
plt.xlabel("True Sensitivity")

# True pixel response map
plt.subplot(2, 3, 4)
plt.title("True Pixel Response")
plt.xlabel("Pixels")
plt.ylabel("Pixels")
plt.imshow(true_pr_masked)
plt.colorbar()

# Recovered map, shown on the true flat field's colour scale
vmin = np.min(pix_response)
vmax = np.max(pix_response)
plt.subplot(2, 3, 5)
plt.title("Found Pixel Response")
plt.xlabel("Pixels")
plt.ylabel("Pixels")
plt.imshow(found_pr_masked, vmin=vmin, vmax=vmax)
plt.colorbar()

# Residual map with a fixed symmetric scale
plt.subplot(2, 3, 6)
plt.title("Pixel Response Residual")
plt.xlabel("Pixels")
plt.ylabel("Pixels")
plt.imshow(true_pr_masked - found_pr_masked, vmin=-0.2, vmax=0.2)
plt.colorbar()
plt.show()